package classification.weibo_feeds

import org.ansj.splitWord.analysis.ToAnalysis
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.feature.{HashingTF, IDF}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.SparkSession

import scala.collection.JavaConversions._

/**
  * Weibo feed sentiment classification with Spark MLlib.
  *
  * Reads `|**|`-separated feed records, segments each feed text with Ansj, weights the terms
  * with TF-IDF, trains a multinomial Naive Bayes model on a random 60% of the records and
  * prints its accuracy on the remaining 40%.
  *
  * NOTE(review): despite the `KNN` suffix, this object trains Naive Bayes, not k-nearest
  * neighbours; the name is kept so existing launch configurations keep working.
  *
  * Created by peibin on 2017/3/20.
  */
object WeiboFeedsKNN {

  import scala.collection.JavaConverters._
  import scala.util.Try

  /** Input glob used when no path is given as the first command-line argument. */
  private val DefaultInputPath =
    """file:///Users/peibin/workspace/data-mining/带有转发和情感标签的微博数据/data/*.txt"""

  /** Regex matching the literal `|**|` column separator of the input lines. */
  private val FieldSeparator = """\|\*\*\|"""

  /**
    * Entry point.
    *
    * @param args optional; `args(0)` overrides the input path (any path or glob Spark can read),
    *             otherwise `DefaultInputPath` is used
    */
  def main(args: Array[String]): Unit = {
    val inputPath = args.headOption.getOrElse(DefaultInputPath)

    val spark = SparkSession
      .builder()
      .appName("WeiboFeeds")
      .master("local[*]")
      .getOrCreate()

    // Always release the Spark context, even if reading or training throws.
    try trainAndEvaluate(spark, inputPath)
    finally spark.stop()
  }

  /** Loads the feeds, builds TF-IDF features, trains Naive Bayes and prints the test accuracy. */
  private def trainAndEvaluate(spark: SparkSession, inputPath: String): Unit = {
    // Lines with an unexpected column count or non-integer numeric columns are skipped.
    val feeds = spark.read.textFile(inputPath).rdd.flatMap(parseFeed(_).toList)
    // Read twice below (tokenization and the label zip), so avoid re-reading the input.
    feeds.cache()

    val documents = feeds.map(f => tokenize(f.feed))

    val hashingTF = new HashingTF()
    val tf = hashingTF.transform(documents)
    tf.cache()

    // Terms occurring in fewer than 2 documents get an IDF weight of 0, i.e. they are filtered out.
    val idf = new IDF(minDocFreq = 2).fit(tf)
    val tfidf = idf.transform(tf)

    // `tfidf` is derived from `feeds` purely through map steps, so both RDDs have the same
    // partitioning and per-partition element order, which is what RDD.zip requires.
    val data = tfidf.zip(feeds).map { case (features, feed) =>
      LabeledPoint(feed.sentiment.doubleValue(), features)
    }

    val Array(training, test) = data.randomSplit(Array(0.6, 0.4))

    val nbModel = NaiveBayes.train(training, lambda = 1.0)

    val testCount = test.count()
    if (testCount == 0) {
      // Without this guard the accuracy would be 0.0 / 0 = NaN.
      println("accuracy: undefined (empty test set)")
    } else {
      val correct = test.filter(p => nbModel.predict(p.features) == p.label).count()
      val accuracy = 1.0 * correct / testCount
      println(s"accuracy: ${accuracy}")
    }
  }

  /**
    * Parses one input line into a [[Feeds]] record.
    *
    * Accepts 3 columns (`poster|**|feed|**|sentiment`) or 4 columns (an extra trailing retweet
    * column). The first, third and optional fourth columns must parse as integers.
    *
    * @param line one raw line of the input files
    * @return the parsed record, or `None` when the line matches neither layout
    */
  private def parseFeed(line: String): Option[Feeds] =
    line.split(FieldSeparator) match {
      case Array(poster, feed, sentiment) =>
        Try(Feeds(poster.toInt, feed, sentiment.toInt, null)).toOption
      case Array(poster, feed, sentiment, reTweeted) =>
        Try(Feeds(poster.toInt, feed, sentiment.toInt, reTweeted.toInt)).toOption
      case _ => None
    }

  /** Segments `text` with Ansj and returns each term's `getRealName`, in order. */
  private def tokenize(text: String): List[String] =
    ToAnalysis.parse(text).getTerms.asScala.map(_.getRealName).toList

  /**
    * One parsed input record.
    *
    * @param poster    first column as an integer (presumably a poster/user id; TODO confirm)
    * @param feed      the feed text, tokenized into TF-IDF features
    * @param sentiment third column; used as the classification label
    * @param reTweetd  optional fourth column as an integer; `null` when the line has 3 columns
    */
  case class Feeds(poster: Integer, feed: String, sentiment: Integer, reTweetd: Integer)
}

/**
  * Empty class, presumably reserved for a k-nearest-neighbours implementation (TODO confirm).
  * It currently declares no members.
  */
class KNN